{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Warp Core Tutorial: Generics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install warp-lang"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warp as wp\n",
    "\n",
    "wp.config.quiet = True\n",
    "\n",
    "# Explicitly initializing Warp is not necessary but\n",
    "# we do it here to ensure everything is good to go.\n",
    "wp.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function Overloading\n",
    "\n",
    "Warp allows defining multiple functions with the same name, as long as their parameter signatures differ. The overload to call is selected from the types of the arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@wp.func\n",
    "def product(\n",
    "    v: wp.vec2,\n",
    ") -> float:\n",
    "    return v[0] * v[1]\n",
    "\n",
    "\n",
    "@wp.func\n",
    "def product(\n",
    "    m: wp.mat22,\n",
    ") -> float:\n",
    "    return m[0, 0] * m[0, 1] * m[1, 0] * m[1, 1]\n",
    "\n",
    "\n",
    "# Define a kernel that multiplies the product of the vector's components\n",
    "# by the product of the matrix's components.\n",
    "@wp.kernel\n",
    "def product_kernel(\n",
    "    v: wp.vec2,\n",
    "    m: wp.mat22,\n",
    "    out_product: wp.array(dtype=float),\n",
    "):\n",
    "    out_product[0] = product(v) * product(m)\n",
    "\n",
    "\n",
    "print(\"\\nproduct:\")\n",
    "v = wp.vec2(2.0, 4.0)\n",
    "m = wp.mat22(3.0, 5.0, 7.0, 9.0)\n",
    "out_product = wp.empty(1, dtype=float)\n",
    "wp.launch(product_kernel, dim=1, inputs=(v, m), outputs=(out_product,))\n",
    "print(out_product)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generic Functions\n",
    "\n",
    "A complementary approach to overloading functions is to use one of the generic types `typing.Any`, `wp.Int`, `wp.Float`, or `wp.Scalar`, and let Warp infer the final function's signature based on the arguments being passed to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This function works with integer and floating-point types of any width.\n",
    "@wp.func\n",
    "def square(x: wp.Scalar) -> wp.Scalar:\n",
    "    return x * x\n",
    "\n",
    "\n",
    "# Define two kernels that square the values of an array,\n",
    "# one for 16-bit integers and another for 64-bit floating-point values.\n",
    "@wp.kernel\n",
    "def square_kernel_i16(arr: wp.array(dtype=wp.int16)):\n",
    "    i = wp.tid()\n",
    "    arr[i] = square(arr[i])\n",
    "\n",
    "\n",
    "@wp.kernel\n",
    "def square_kernel_f64(arr: wp.array(dtype=wp.float64)):\n",
    "    i = wp.tid()\n",
    "    arr[i] = square(arr[i])\n",
    "\n",
    "\n",
    "# Launch the first kernel: `square()` is implicitly instantiated for 16-bit integers.\n",
    "print(\"\\narr_i16:\")\n",
    "arr_i16 = wp.array((1, 2, 3), dtype=wp.int16)\n",
    "wp.launch(square_kernel_i16, dim=arr_i16.shape, inputs=(arr_i16,))\n",
    "print(arr_i16)\n",
    "\n",
    "# Launch the second kernel: `square()` is implicitly instantiated for 64-bit floating-point values.\n",
    "print(\"\\narr_f64:\")\n",
    "arr_f64 = wp.array((4, 5, 6), dtype=wp.float64)\n",
    "wp.launch(square_kernel_f64, dim=arr_f64.shape, inputs=(arr_f64,))\n",
    "print(arr_f64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generic Kernels\n",
    "\n",
    "The same generic types `typing.Any`, `wp.Int`, `wp.Float`, and `wp.Scalar` can also be used when annotating parameters on a kernel.\n",
    "\n",
    "To generate the final kernels from such generic types, Warp supports implicit and explicit instantiations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implicit Instantiation\n",
    "\n",
    "By default, Warp infers the final kernel's signature and implementation based on the arguments being passed to it when calling `wp.launch()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a kernel that scales the values of an array by a coefficient.\n",
    "# Its elements can be integers or floating-point values of any width.\n",
    "@wp.kernel\n",
    "def scale_kernel(arr: wp.array(dtype=wp.Scalar), coeff: wp.Scalar):\n",
    "    i = wp.tid()\n",
    "    arr[i] *= coeff\n",
    "\n",
    "\n",
    "# First implicit kernel instantiation with a 16-bit integer type.\n",
    "print(\"arr_i16:\")\n",
    "arr_i16 = wp.array((1, 2, 3), dtype=wp.int16)\n",
    "wp.launch(scale_kernel, dim=arr_i16.shape, inputs=(arr_i16, wp.int16(2)))\n",
    "print(arr_i16)\n",
    "\n",
    "# Second implicit kernel instantiation with a 64-bit floating-point type.\n",
    "print(\"\\narr_f64:\")\n",
    "arr_f64 = wp.array((4, 5, 6), dtype=wp.float64)\n",
    "wp.launch(scale_kernel, dim=arr_f64.shape, inputs=(arr_f64, wp.float64(2)))\n",
    "print(arr_f64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explicit Instantiation\n",
    "\n",
    "It's also possible to specify ahead of time which types a kernel should be instantiated for, before any call to `wp.launch()`. This is done using the `@wp.overload` decorator.\n",
    "\n",
    "One advantage of this approach is faster kernel launches, since Warp doesn't have to infer and generate a new kernel instance at launch time. Another relates to module reloading, as detailed in the [documentation](https://nvidia.github.io/warp/modules/generics.html#module-reloading-behavior)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a kernel that scales the values of an array by a coefficient.\n",
    "# Its elements can be integers or floating-point values of any width.\n",
    "@wp.kernel\n",
    "def scale_kernel(arr: wp.array(dtype=wp.Scalar), coeff: wp.Scalar):\n",
    "    i = wp.tid()\n",
    "    arr[i] *= coeff\n",
    "\n",
    "\n",
    "# Explicit instantiation for 16-bit integers.\n",
    "@wp.overload\n",
    "def scale_kernel(arr: wp.array(dtype=wp.int16), coeff: wp.int16):\n",
    "    ...\n",
    "\n",
    "\n",
    "# Explicit instantiation for 64-bit floating-point values.\n",
    "@wp.overload\n",
    "def scale_kernel(arr: wp.array(dtype=wp.float64), coeff: wp.float64):\n",
    "    ...\n",
    "\n",
    "\n",
    "# Launch the kernel instance using a 16-bit integer type.\n",
    "print(\"arr_i16:\")\n",
    "arr_i16 = wp.array((1, 2, 3), dtype=wp.int16)\n",
    "wp.launch(scale_kernel, dim=arr_i16.shape, inputs=(arr_i16, wp.int16(2)))\n",
    "print(arr_i16)\n",
    "\n",
    "# Launch the kernel instance using a 64-bit floating-point type.\n",
    "print(\"\\narr_f64:\")\n",
    "arr_f64 = wp.array((4, 5, 6), dtype=wp.float64)\n",
    "wp.launch(scale_kernel, dim=arr_f64.shape, inputs=(arr_f64, wp.float64(2)))\n",
    "print(arr_f64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Type Introspection\n",
    "\n",
    "Because Warp is strictly typed and has no integer/floating-point promotion rules, functions must be called with the exact argument types they expect. For example, when constructing a `wp.vec3s()` instance, each argument must be explicitly cast to `wp.int16` if it isn't of that type already, as in `wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3))`, since integer literals default to 32-bit.\n",
    "\n",
    "In the context of a generic kernel/function, where the parameter type is only known at runtime, Warp exposes a `type()` operator that retrieves the resolved type of a variable so that it can be used to initialize or cast values.\n",
    "\n",
    "To retrieve the data type of an array's elements, `type()` can be called on one of its elements, but `array.dtype` offers a more convenient form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a kernel that increases the values of an array by a fixed amount.\n",
    "@wp.kernel\n",
    "def increase_kernel(arr: wp.array(dtype=wp.Scalar)):\n",
    "    i = wp.tid()\n",
    "\n",
    "    # Both forms below resolve to the array's element type.\n",
    "    arr[i] += type(arr[0])(2)\n",
    "    arr[i] += arr.dtype(3)\n",
    "\n",
    "\n",
    "# Launch the kernel instance using a 16-bit integer type.\n",
    "print(\"arr_i16:\")\n",
    "arr_i16 = wp.array((1, 2, 3), dtype=wp.int16)\n",
    "wp.launch(increase_kernel, dim=arr_i16.shape, inputs=(arr_i16,))\n",
    "print(arr_i16)\n",
    "\n",
    "# Launch the kernel instance using a 64-bit floating-point type.\n",
    "print(\"\\narr_f64:\")\n",
    "arr_f64 = wp.array((4, 5, 6), dtype=wp.float64)\n",
    "wp.launch(increase_kernel, dim=arr_f64.shape, inputs=(arr_f64,))\n",
    "print(arr_f64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dynamic Code Generation\n",
    "\n",
    "When the approaches covered so far aren't flexible enough, we can leverage Python's dynamic nature to generate kernels, functions, and even structs at runtime, using closures that take values, types, or even functions as parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define some operator functions that we can pass to the kernel factory below.\n",
    "\n",
    "\n",
    "@wp.func\n",
    "def op_add(a: wp.Scalar, b: wp.Scalar) -> wp.Scalar:\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@wp.func\n",
    "def op_mul(a: wp.Scalar, b: wp.Scalar) -> wp.Scalar:\n",
    "    return a * b\n",
    "\n",
    "\n",
    "# Closure creating and returning a kernel.\n",
    "# All the argument values will be embedded into the generated code\n",
    "# that is to be compiled against the target architecture (CUDA or C++).\n",
    "def create_kernel(vec_length: int, vec_dtype: wp.Scalar, num_iter: int, op_fn: wp.Function) -> wp.Kernel:\n",
    "    # Define the vector type from its length/dtype.\n",
    "    vec = wp.vec(vec_length, vec_dtype)\n",
    "\n",
    "    # Define a function that reduces a vector into a single value by summing\n",
    "    # the operator function applied to each component and its index.\n",
    "    @wp.func\n",
    "    def reduce(v: vec) -> vec_dtype:\n",
    "        out = vec_dtype(0)\n",
    "        for i in range(vec_length):\n",
    "            out += op_fn(v[i], vec_dtype(i))\n",
    "\n",
    "        return out\n",
    "\n",
    "    # Define the kernel function to return.\n",
    "    @wp.kernel\n",
    "    def kernel(arr: wp.array(dtype=vec)):\n",
    "        tid = wp.tid()\n",
    "\n",
    "        v = vec()\n",
    "        for i in range(vec_length):\n",
    "            v[i] = vec_dtype(tid + i)\n",
    "\n",
    "        for _ in range(num_iter):\n",
    "            v *= reduce(v)\n",
    "\n",
    "        arr[tid] = v\n",
    "\n",
    "    return kernel\n",
    "\n",
    "\n",
    "# Generate and evaluate a first kernel.\n",
    "print(\"arr_1:\")\n",
    "vec_length = 2\n",
    "vec_dtype = wp.int32\n",
    "num_iter = 3\n",
    "op_fn = op_mul\n",
    "arr_1 = wp.empty(3, dtype=wp.vec(vec_length, vec_dtype))\n",
    "kernel_1 = create_kernel(vec_length, vec_dtype, num_iter, op_fn)\n",
    "wp.launch(kernel_1, dim=arr_1.shape, inputs=(arr_1,))\n",
    "print(arr_1)\n",
    "\n",
    "# Generate and evaluate a second kernel.\n",
    "print(\"\\narr_2:\")\n",
    "vec_length = 3\n",
    "vec_dtype = wp.float64\n",
    "num_iter = 2\n",
    "op_fn = op_add\n",
    "arr_2 = wp.empty(3, dtype=wp.vec(vec_length, vec_dtype))\n",
    "kernel_2 = create_kernel(vec_length, vec_dtype, num_iter, op_fn)\n",
    "wp.launch(kernel_2, dim=arr_2.shape, inputs=(arr_2,))\n",
    "print(arr_2)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
